"""Created on April, 4, 2022.

Optimal Policy Trees: Functions - Python implementation

Can be used under Creative Commons Licence CC BY-SA
Michael Lechner, SEW, University of St. Gallen, Switzerland

# -*- coding: utf-8 -*-
"""
from math import inf

import pandas as pd
import numpy as np

from mcf import optp_tree_functions as optp_t
from mcf import optp_tree_add_functions as optp_ta


def black_box_allocation(indata, preddata, c_dict, v_dict, seed):
    """
    Organise the estimation of the black-box (PO based) allocations.

    Parameters
    ----------
    indata : String
        Training data (csv file).
    preddata : String
        Prediction data (csv file).
    c_dict : Dict.
        Control parameters. Keys read here: 'only_if_sig_better_vs_0',
        'with_output'.
    v_dict : Dict
        Variables. Keys read here: 'polscore_name', 'polscore_desc_name',
        'd_name'.
    seed : Int
        Seed of the random number generator used for the random allocations.

    Returns
    -------
    black_box_alloc_pred : Tuple of dict or None
        Allocations computed for the prediction data. None if the prediction
        data is not used (same file as training data, or it does not contain
        all required policy score variables).

    """
    rng = np.random.default_rng(seed)
    if indata != preddata:
        if v_dict['polscore_desc_name'] is None:
            vars_to_check = v_dict['polscore_name']
        else:
            vars_to_check = (v_dict['polscore_name']
                             + v_dict['polscore_desc_name'])
        # Prediction data can only be used if it contains all scores
        use_pred_data = vars_in_data_file(preddata, vars_to_check)
    else:
        use_pred_data = False
    data_files = [indata, preddata] if use_pred_data else [indata]
    black_box_alloc_pred = None
    for idx, data_file in enumerate(data_files):
        if c_dict['only_if_sig_better_vs_0']:
            po_np, data_df = optp_t.adjust_policy_score(data_file, c_dict,
                                                        v_dict)
        else:
            data_df = pd.read_csv(data_file)
            po_np = data_df[v_dict['polscore_name']].to_numpy()
        if v_dict['polscore_desc_name'] is not None:
            po_descr_np = data_df[v_dict['polscore_desc_name']].to_numpy()
        else:
            po_descr_np = None
        allocation = bb_allocation(po_np, data_df, c_dict, v_dict, rng)
        if c_dict['with_output']:
            # Observed treatments are only used for the training data
            # (idx == 0). For the prediction data None is passed on; calling
            # .to_numpy() on None would raise an AttributeError.
            if idx == 0:
                treatment = data_df[v_dict['d_name']].to_numpy()
            else:
                treatment = None
            bb_allocation_stats(allocation, po_np, c_dict, v_dict, data_file,
                                po_descr_np, treatment)
        # Only the allocation of the prediction data (idx == 1) is returned
        black_box_alloc_pred = allocation if idx == 1 else None
    return black_box_alloc_pred


def bb_allocation(po_np_, data_df, c_dict, v_dict, rng):
    """
    Generate the various black-box allocations.

    Parameters
    ----------
    po_np_ : 2D Numpy array
        Policy scores (rows: observations, columns: treatments). The array
        is copied and not modified.
    data_df : Dataframe
        Contains dataframe with all variables. Only read if
        v_dict['bb_rest_variable'] is set.
    c_dict : Dict
        Controls. Keys read here: 'restricted', 'max_shares'.
    v_dict : Dict
        Variables. Key read here: 'bb_rest_variable'.
    rng : Numpy random Generator
        Used for the random allocations and the random ordering.

    Returns
    -------
    allocations : Tuple of dict
        Each dictionary has the keys 'type' (description) and 'alloc'
        (1D Numpy array with the allocated treatment of each observation).
        Unrestricted case: (random, largest gain). Restricted case:
        additionally restricted random, restricted largest gain, restricted
        largest gain with random order and, if v_dict['bb_rest_variable'] is
        set, restricted largest gain ordered by that variable.
    """
    po_np = po_np_.copy()
    no_obs, no_treat = po_np.shape[0], po_np.shape[1]
    largest_gain = {'type': 'Unrestricted: Largest gain',
                    'alloc': np.argmax(po_np, axis=1)}
    random = {'type': 'Unrestricted: random',
              'alloc': rng.integers(0, high=no_treat, size=no_obs)}
    if not c_dict['restricted']:
        return (random, largest_gain)

    # Maximum number of observations allowed in each treatment
    max_by_cat = np.int64(
        np.floor(no_obs * np.array(c_dict['max_shares'])))

    # Random allocation that respects the restrictions
    random_rest = {'type': 'Restricted: random',
                   'alloc': np.zeros_like(random['alloc'])}
    so_far_by_cat = np.zeros_like(max_by_cat)
    for idx in range(no_obs):
        for _ in range(10):   # at most 10 trials, else zero default remains
            # size=1 is kept for the draw itself; the scalar is extracted
            # explicitly instead of indexing/assigning with a 1-element array
            draw = int(rng.integers(0, high=no_treat, size=1)[0])
            # Strict inequality: a treatment that has already reached its
            # maximum must not receive another observation.
            if so_far_by_cat[draw] < max_by_cat[draw]:
                so_far_by_cat[draw] += 1
                random_rest['alloc'][idx] = draw
                break

    # Largest gain, observations with largest gain (relative to treatment 0)
    # are served first
    val_best_treat = (po_np[np.arange(no_obs), largest_gain['alloc']]
                      - po_np[:, 0])
    order_best_treat = np.flip(np.argsort(val_best_treat))
    largest_gain_rest = {
        'type': 'Restricted: Largest gain',
        'alloc': largest_gain_rest_fct(
            order_best_treat, largest_gain['alloc'], max_by_cat,
            po_np.copy(), no_treat)}

    # Largest gain, observations served in random order
    order_random = np.arange(no_obs)
    rng.shuffle(order_random)
    largest_gain_rest_random_order = {
        'type': 'Restricted: Largest gain - first come first served',
        'alloc': largest_gain_rest_fct(
            order_random, largest_gain['alloc'], max_by_cat, po_np.copy(),
            no_treat)}

    if not v_dict['bb_rest_variable']:
        return (random, largest_gain, random_rest, largest_gain_rest,
                largest_gain_rest_random_order)

    # Largest gain, observations served in descending order of other variable
    order_other_var = np.flip(
        np.argsort(data_df[v_dict['bb_rest_variable']].to_numpy(), axis=0))
    order_other_var = [x[0] for x in order_other_var]
    largest_gain_rest_other_var = {
        'type': ('Restricted: Largest gain - based on ' +
                 str(*v_dict['bb_rest_variable'])),
        'alloc': largest_gain_rest_fct(
            order_other_var, largest_gain['alloc'], max_by_cat,
            po_np.copy(), no_treat)}
    return (random, largest_gain, random_rest, largest_gain_rest,
            largest_gain_rest_random_order, largest_gain_rest_other_var)


def largest_gain_rest_fct(order_treat, largest_gain_alloc, max_by_cat, po_np,
                          no_treat):
    """
    Get index of largest gain under restr. for each obs with given order.

    Observations are processed in the sequence given by order_treat. An
    observation obtains its best treatment if that treatment has capacity
    left. Otherwise the second best (if more than 2 treatments) and then the
    third best (if more than 3 treatments) are tried. If none of them has
    capacity left, the observation remains at the default treatment 0.

    Parameters
    ----------
    order_treat : Iterable of int
        Indices of observations in the order in which they are served.
    largest_gain_alloc : 1D Numpy array of int
        Unrestricted best treatment of each observation.
    max_by_cat : 1D Numpy array of int
        Maximum number of observations allowed in each treatment.
    po_np : 2D Numpy array
        Policy scores (rows: observations, columns: treatments). Not changed.
    no_treat : Int
        Number of treatments.

    Returns
    -------
    largest_gain_rest : 1D Numpy array of int
        Allocated treatment of each observation.
    """
    # Number of next-best treatments that are tried if the best one is full
    if no_treat > 3:
        no_of_fallbacks = 2
    elif no_treat > 2:
        no_of_fallbacks = 1
    else:
        no_of_fallbacks = 0
    so_far_by_cat = np.zeros_like(max_by_cat)
    largest_gain_rest = np.zeros_like(largest_gain_alloc)
    for i in order_treat:
        candidate = largest_gain_alloc[i]
        # Strict inequality: a treatment that has already reached its maximum
        # must not receive another observation.
        if so_far_by_cat[candidate] < max_by_cat[candidate]:
            so_far_by_cat[candidate] += 1
            largest_gain_rest[i] = candidate
            continue
        # Work on a float copy of the scores of this observation, so that
        # -inf can be stored and the array of the caller is left untouched
        scores_i = po_np[i].astype(float)
        for _ in range(no_of_fallbacks):
            scores_i[candidate] = -inf    # exclude treatments that are full
            candidate = np.argmax(scores_i)
            if so_far_by_cat[candidate] < max_by_cat[candidate]:
                so_far_by_cat[candidate] += 1
                largest_gain_rest[i] = candidate
                break
        # otherwise it remains at the zero default
    return largest_gain_rest


def _average_or_nan(total, count):
    """Return total / count, or NaN if count is zero (no division by 0)."""
    return total / count if count > 0 else float('nan')


def bb_allocation_stats(allocation, po_np, c_dict, v_dict, data_file,
                        po_descr_np, treatment=None):
    """
    Show descriptive stats for the various black-box allocations.

    Parameters
    ----------
    allocation : Iterable of dict
        Allocations. Each dictionary has the keys 'type' (description) and
        'alloc' (allocated treatment of each observation).
    po_np : 2D Numpy array
        Policy scores (rows: observations, columns: treatments).
    c_dict : Dict
        Controls. Keys read here: 'only_if_sig_better_vs_0',
        'sig_level_vs_0', 'costs_of_treat', 'max_by_treat'.
    v_dict : Dict
        Variables. Keys read here: 'polscore_name', 'polscore_desc_name'.
    data_file : String
        Name of data file (printed only).
    po_descr_np : Numpy array or None
        Additional outcomes to be described. None: No additional output.
    treatment : Numpy array or None, optional
        Observed treatments. If given, statistics for observations whose
        allocated treatment differs from the observed one ('changers') are
        shown. Arrays of shape (n, 1) are flattened. The default is None.

    Returns
    -------
    None.

    """
    print('=' * 80, '\n' + 'Black-Box approaches', '\n' + '-' * 80)
    if c_dict['only_if_sig_better_vs_0']:
        print('While tree building, policy scores not significantly',
              'different from zero are set to zero. Below, orginal scores',
              'are used.')
        print('Significance level used for recoding:',
              f' {c_dict["sig_level_vs_0"] * 100:6.4f} %')
    print('-' * 80, '\n')
    print(data_file)
    print('-' * 80)
    print('Policy scores: ', end=' ')
    for i in v_dict['polscore_name']:
        print(i, end=' ')
    print('\n' + '-' * 80)
    total_obs, no_of_treat = len(po_np), po_np.shape[1]
    obs_idx = np.arange(total_obs)
    costs_of_treat = np.asarray(c_dict['costs_of_treat'])
    if treatment is not None:
        # A single-column 2D array would otherwise broadcast to (n, n) in the
        # comparison with the allocation
        treatment = np.asarray(treatment).reshape(-1)
    for alloc in allocation:
        alloc_np = np.asarray(alloc['alloc'])
        obs_by_treat = np.array(
            [np.sum(alloc_np == i) for i in range(no_of_treat)], dtype=float)
        print()
        print(alloc['type'])
        print('-' * 80)
        # Score of the allocated treatment for each observation
        score_alloc = po_np[obs_idx, alloc_np]
        total_score = score_alloc.sum()
        if treatment is not None:
            changers = alloc_np != treatment
            obs_change = int(changers.sum())
            total_score_change = score_alloc[changers].sum()
        else:
            changers = np.zeros(total_obs, dtype=bool)
            obs_change, total_score_change = 0, 0
        total_cost = np.sum(costs_of_treat * obs_by_treat)
        average_score = _average_or_nan(total_score, total_obs)
        average_cost = _average_or_nan(total_cost, total_obs)
        average_net = _average_or_nan(total_score - total_cost, total_obs)
        print(f'Total score:        {total_score:14.4f} ',
              f'  Average score:        {average_score:14.4f}')
        print(f'Total cost:         {total_cost:14.4f} ',
              f'  Average cost:         {average_cost:14.4f}')
        print(f'Total score - cost: {total_score-total_cost:14.4f} ',
              '  Average score - cost:',
              f'{average_net:14.4f}')
        print('- ' * 40)
        print(f'Total number of observations: {int(total_obs):d}')
        print('Treatments:                           ', end=' ')
        for i, _ in enumerate(v_dict['polscore_name']):
            print(f'{i:6d} ', end=' ')
        print('\nObservations allocated, by treatment: ', end=' ')
        for i in obs_by_treat:
            print(f'{int(i):6d} ', end=' ')
        print('\nObservations allowed,   by treatment: ', end=' ')
        for i in c_dict['max_by_treat']:
            print(f'{int(i):6d} ', end=' ')
        print('\nCost per treatment:                   ', end=' ')
        for i in c_dict['costs_of_treat']:
            print(f'{i:6.2f} ', end=' ')
        print('\n' + '-' * 80)
        if treatment is not None:
            # Without any changer the average is undefined (shown as nan)
            # instead of raising a ZeroDivisionError
            average_change = _average_or_nan(total_score_change, obs_change)
            print(f'Number of changers: {obs_change:d}')
            print(f'Total score (changers): {total_score_change:12.4f} ',
                  ' Average score (changers):',
                  f' {average_change:12.4f}')
            print('-' * 80)
        if po_descr_np is not None:
            optp_ta.describe_alloc_other_outcomes(
                v_dict['polscore_desc_name'], po_descr_np, no_of_treat,
                alloc['alloc'], changers)


def vars_in_data_file(preddata, var_list):
    """
    Check if potential outcomes are in prediction data.

    Parameters
    ----------
    preddata : String
        Name of csv file. Only the header line is read.
    var_list : List of strings
        Names that are compared with the upper-cased column names of the
        file.

    Returns
    -------
    Boolean
        True if every name of var_list is among the upper-cased column names.
    """
    columns = pd.read_csv(filepath_or_buffer=preddata, nrows=0).columns
    upper_case_names = [name.upper() for name in columns.values]
    for var in var_list:
        if var not in upper_case_names:
            return False
    return True
